{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concise Implementation of Linear Regression\n",
    ":label:`sec_linear_djl`\n",
    "\n",
    "Broad and intense interest in deep learning for the past several years\n",
    "has inspired both companies, academics, and hobbyists\n",
    "to develop a variety of mature open source frameworks\n",
    "for automating the repetitive work of implementing\n",
    "gradient-based learning algorithms.\n",
    "In the previous section, we relied only on\n",
    "(i) `NDArray` for data storage and linear algebra;\n",
    "and (ii) `GradientCollector` for calculating derivatives.\n",
    "In practice, because data iterators, loss functions, optimizers,\n",
    "and neural network layers (and some whole architectures)\n",
    "are so common, modern libraries implement these components for us as well.\n",
    "\n",
    "In this section, we will show you how to implement\n",
    "the linear regression model from :numref:`sec_linear_scratch`\n",
    "concisely by using DJL.\n",
    "\n",
    "## Generating the Dataset\n",
    "\n",
    "To start, we will generate the same dataset as in the previous section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.10.0\n",
    "%maven ai.djl:model-zoo:0.8.0\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "%maven net.java.dev.jna:jna:5.3.0\n",
    "    \n",
    "%maven ai.djl.mxnet:mxnet-engine:0.10.0\n",
    "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/DataPoints.java\n",
    "%load ../utils/Training.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.Model;\n",
    "import ai.djl.metric.Metrics;\n",
    "import ai.djl.ndarray.NDArray;\n",
    "import ai.djl.ndarray.NDManager;\n",
    "import ai.djl.ndarray.types.Shape;\n",
    "import ai.djl.nn.Block;\n",
    "import ai.djl.nn.ParameterList;\n",
    "import ai.djl.nn.SequentialBlock;\n",
    "import ai.djl.nn.core.Linear;\n",
    "import ai.djl.training.DefaultTrainingConfig;\n",
    "import ai.djl.training.EasyTrain;\n",
    "import ai.djl.training.Trainer;\n",
    "import ai.djl.training.dataset.ArrayDataset;\n",
    "import ai.djl.training.dataset.Batch;\n",
    "import ai.djl.training.listener.TrainingListener;\n",
    "import ai.djl.training.loss.Loss;\n",
    "import ai.djl.training.optimizer.Optimizer;\n",
    "import ai.djl.training.tracker.Tracker;\n",
    "import ai.djl.translate.TranslateException;\n",
    "\n",
    "import java.io.IOException;\n",
    "import java.nio.file.*;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "NDArray trueW = manager.create(new float[]{2, -3.4f});\n",
    "float trueB = 4.2f;\n",
    "\n",
    "DataPoints dp = DataPoints.syntheticData(manager, trueW, trueB, 1000);\n",
    "NDArray features = dp.getX();\n",
    "NDArray labels = dp.getY();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reading the Dataset\n",
    "\n",
    "Just like in the last section,\n",
    "we can call upon DJL's `dataset` package to read data.\n",
    "The first step will be to instantiate an `ArrayDataset`.\n",
    "Here, we set the `features` and `labels` as parameters.\n",
    "We also specify a `batchSize`\n",
    "and specify a Boolean value `shuffle` indicating whether or not\n",
    "we want the `ArrayDataset` to randomly sample the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [],
   "source": [
    "// Saved in the utils file for later use\n",
    "public ArrayDataset loadArray(NDArray features, NDArray labels, int batchSize, boolean shuffle) {\n",
    "    return new ArrayDataset.Builder()\n",
    "                  .setData(features) // set the features\n",
    "                  .optLabels(labels) // set the labels\n",
    "                  .setSampling(batchSize, shuffle) // set the batch size and random sampling\n",
    "                  .build();\n",
    "}\n",
    "\n",
    "int batchSize = 10;\n",
    "ArrayDataset dataset = loadArray(features, labels, batchSize, false);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To verify that it is working, we can read and print\n",
    "the first minibatch of instances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [],
   "source": [
    "for (Batch batch : dataset.getData(manager)) {\n",
    "    NDArray X = batch.getData().head();\n",
    "    NDArray y = batch.getLabels().head();\n",
    "    System.out.println(X);\n",
    "    System.out.println(y);\n",
    "    batch.close();\n",
    "    break;\n",
    "}"
   ]
  },
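   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "As a rough illustration of what `batchSize` and `shuffle` control,\n",
     "the plain-Java sketch below splits example indices into minibatches.\n",
     "It does not use DJL: the class `MiniBatchSketch`, the method `batchIndices`,\n",
     "and the `seed` parameter are hypothetical names introduced here,\n",
     "and DJL's own sampler may differ in detail.\n",
     "\n",
     "```java\n",
     "import java.util.ArrayList;\n",
     "import java.util.Collections;\n",
     "import java.util.List;\n",
     "import java.util.Random;\n",
     "\n",
     "public class MiniBatchSketch {\n",
     "    // Splits the indices 0..numExamples-1 into consecutive groups of at most batchSize.\n",
     "    // If shuffle is true, the indices are permuted first (seeded for reproducibility).\n",
     "    public static List<List<Integer>> batchIndices(\n",
     "            int numExamples, int batchSize, boolean shuffle, long seed) {\n",
     "        List<Integer> indices = new ArrayList<>();\n",
     "        for (int i = 0; i < numExamples; i++) {\n",
     "            indices.add(i);\n",
     "        }\n",
     "        if (shuffle) {\n",
     "            Collections.shuffle(indices, new Random(seed));\n",
     "        }\n",
     "        List<List<Integer>> batches = new ArrayList<>();\n",
     "        for (int start = 0; start < numExamples; start += batchSize) {\n",
     "            int end = Math.min(start + batchSize, numExamples);\n",
     "            batches.add(new ArrayList<>(indices.subList(start, end)));\n",
     "        }\n",
     "        return batches;\n",
     "    }\n",
     "}\n",
     "```\n",
     "\n",
     "With 1000 examples and a batch size of 10, this produces 100 minibatches per epoch."
    ]
   },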
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the Model\n",
    "\n",
    "When we implemented linear regression from scratch\n",
    "(in :numref:`sec_linear_scratch`),\n",
    "we defined our model parameters explicitly\n",
    "and coded up the calculations to produce output\n",
    "using basic linear algebra operations.\n",
    "You *should* know how to do this.\n",
    "But once your models get more complex,\n",
    "and once you have to do this nearly every day,\n",
    "you will be glad for the assistance.\n",
    "The situation is similar to coding up your own blog from scratch.\n",
    "Doing it once or twice is rewarding and instructive,\n",
    "but you would be a lousy web developer\n",
    "if every time you needed a blog you spent a month\n",
    "reinventing the wheel.\n",
    "\n",
    "For standard operations, we can use DJL's predefined blocks,\n",
    "which allow us to focus especially\n",
    "on the layers used to construct the model\n",
    "rather than having to focus on the implementation.\n",
    "To define a linear model, we first import the `Model` class,\n",
    "which defines a lot of useful methods to interact with our `model`.\n",
    "We will first define a model variable `model`.\n",
    "We will then instantiate a SequentialBlock variable `net`, which will refer to an instance of the `SequentialBlock` class. The `SequentialBlock` class defines a container for several layers that will be chained together. Given input data, a `SequentialBlock` passes it through the first layer, in turn passing the output as the second layer’s input and so forth. In the following example, our model consists of only one layer, so we do not really need `SequentialBlock`. But since nearly all of our future models will involve multiple layers, we will use it anyway just to familiarize you with the most standard workflow."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recall the architecture of a single-layer network as shown in :numref:`fig_singleneuron`.\n",
    "The layer is said to be *fully-connected*\n",
    "because each of its inputs are connected to each of its outputs\n",
    "by means of a matrix-vector multiplication.\n",
    "In DJL, we can use a `Linear` block to apply a linear transformation.\n",
    "We simply set the number of outputs (in our case its set to 1) and choose\n",
    "if we want to include a bias(yes).\n",
    "\n",
    "![Linear regression is a single-layer neural network. ](https://raw.githubusercontent.com/awslabs/djl-resources/311169c9dbd89e1597a5ae6b0b1ba8a402b8b55e/jupyter/d2l-java/img/singleneuron.svg)\n",
    ":label:`fig_singleneuron`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "6"
    }
   },
   "outputs": [],
   "source": [
    "Model model = Model.newInstance(\"lin-reg\");\n",
    "\n",
    "SequentialBlock net = new SequentialBlock();\n",
    "Linear linearBlock = Linear.builder().optBias(true).setUnits(1).build();\n",
    "net.add(linearBlock);\n",
    "\n",
    "model.setBlock(net);"
   ]
  },
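   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "To make the matrix-vector multiplication concrete, here is a minimal plain-Java sketch\n",
     "of what a fully-connected layer computes for a single example.\n",
     "It is an illustration only, not DJL's implementation;\n",
     "`LinearSketch` and `forward` are hypothetical names introduced here.\n",
     "\n",
     "```java\n",
     "public class LinearSketch {\n",
     "    // Computes y = W x + b, where weights has shape (numOutputs, numInputs).\n",
     "    public static float[] forward(float[][] weights, float[] bias, float[] x) {\n",
     "        float[] y = new float[weights.length];\n",
     "        for (int o = 0; o < weights.length; o++) {\n",
     "            float sum = bias[o];\n",
     "            for (int i = 0; i < x.length; i++) {\n",
     "                sum += weights[o][i] * x[i];\n",
     "            }\n",
     "            y[o] = sum;\n",
     "        }\n",
     "        return y;\n",
     "    }\n",
     "}\n",
     "```\n",
     "\n",
     "With the true parameters used to generate our dataset (weights 2 and -3.4, bias 4.2),\n",
     "the input (1, 1) maps to 2 - 3.4 + 4.2 = 2.8."
    ]
   },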
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the Loss Function\n",
    "\n",
    "In DJL, the `Loss` class defines various loss functions.\n",
    "We will use the imported class `Loss`.\n",
    "In this example, we will use the DJL\n",
    "implementation of squared loss (`L2Loss`).\n",
    "\n",
    "$$\n",
    "L2Loss = \\sum_{i = 1}^{n}(y_i - \\hat{y_i})^2\n",
    "$$\n",
    "\n",
    "L2 Loss or 'Mean Squared Error' is the sum of the squared\n",
    "difference between the true `y` value and the predicted `y`\n",
    "value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "8"
    }
   },
   "outputs": [],
   "source": [
    "Loss l2loss = Loss.l2Loss();"
   ]
  },
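   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The squared-loss formula above can be sketched in plain Java as follows.\n",
     "The class and method names are hypothetical, and the code mirrors the formula in the text\n",
     "rather than DJL's internal implementation, which may apply additional scaling.\n",
     "\n",
     "```java\n",
     "public class SquaredLossSketch {\n",
     "    // Sum of squared differences between labels and predictions.\n",
     "    public static double sumSquaredError(double[] labels, double[] predictions) {\n",
     "        if (labels.length != predictions.length) {\n",
     "            throw new IllegalArgumentException(\"labels and predictions must have the same length\");\n",
     "        }\n",
     "        double total = 0.0;\n",
     "        for (int i = 0; i < labels.length; i++) {\n",
     "            double diff = labels[i] - predictions[i];\n",
     "            total += diff * diff;\n",
     "        }\n",
     "        return total;\n",
     "    }\n",
     "\n",
     "    // Mean squared error: the sum above divided by the number of examples.\n",
     "    public static double meanSquaredError(double[] labels, double[] predictions) {\n",
     "        return sumSquaredError(labels, predictions) / labels.length;\n",
     "    }\n",
     "}\n",
     "```\n",
     "\n",
     "For labels (1, 2, 3) and predictions (1, 1, 5), the differences are 0, 1, and -2,\n",
     "so the sum of squares is 5 and the mean is 5/3."
    ]
   },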
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the Optimization Algorithm\n",
    "\n",
    "Minibatch SGD and related variants\n",
    "are standard tools for optimizing neural networks\n",
    "and thus DJL supports SGD alongside a number of\n",
    "variations on this algorithm through its `Optimizer` class.\n",
    "When we instantiate the `Optimizer`,\n",
    "we will specify the optimization algorithm we wish to use (`sgd`).\n",
    "We can also manually set hyper-parameters.\n",
    "SGD just requires `learningRate`,\n",
    "here we set it to a fixed rate of 0.03."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "9"
    }
   },
   "outputs": [],
   "source": [
    "Tracker lrt = Tracker.fixed(0.03f);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();"
   ]
  },
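   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "For intuition, a single SGD update moves each parameter a small step against its gradient.\n",
     "The sketch below shows that update in plain Java.\n",
     "`SgdSketch` and `step` are hypothetical names introduced here,\n",
     "and this is not how DJL implements its optimizer internally.\n",
     "\n",
     "```java\n",
     "public class SgdSketch {\n",
     "    // In-place update: params[i] <- params[i] - learningRate * grads[i].\n",
     "    public static void step(float[] params, float[] grads, float learningRate) {\n",
     "        for (int i = 0; i < params.length; i++) {\n",
     "            params[i] -= learningRate * grads[i];\n",
     "        }\n",
     "    }\n",
     "}\n",
     "```\n",
     "\n",
     "With a learning rate of 0.03, a gradient of 10 lowers the corresponding parameter by 0.3."
    ]
   },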
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instantiate Configuration and Trainer\n",
    "Now we'll create a training configuration that\n",
    "describes how we want to train our model.\n",
    "We will then initialize a `trainer` to do the\n",
    "training for us."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DefaultTrainingConfig config = new DefaultTrainingConfig(l2loss)\n",
    "    .optOptimizer(sgd) // Optimizer (loss function)\n",
    "    .optDevices(Device.getDevices(1)) // single GPU\n",
    "    .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializing Model Parameters\n",
    "\n",
    "Before training our model, we need to initialize the model parameters,\n",
    "such as the weights and biases in the linear regression model.\n",
    "We simply call the `initialize` function with the shape of the model\n",
    "we are training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "7"
    }
   },
   "outputs": [],
   "source": [
    "// First axis is batch size - won't impact parameter initialization\n",
    "// Second axis is the input size\n",
    "trainer.initialize(new Shape(batchSize, 2));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics\n",
    "Normally, DJL doesn't record metrics unless explicitly told to\n",
    "as recording metrics impacts the execution flow optimizations.\n",
    "To record metrics, we must instantiate `metrics` from outside\n",
    "the `trainer` object and then pass it in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Metrics metrics = new Metrics();\n",
    "trainer.setMetrics(metrics);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "You might have noticed that expressing our model through DJL\n",
    "requires comparatively few lines of code.\n",
    "We did not have to individually allocate parameters,\n",
    "define our loss function, or implement stochastic gradient descent.\n",
    "Once we start working with much more complex models,\n",
    "DJL's advantages will grow considerably.\n",
    "However, once we have all the basic pieces in place,\n",
    "the training loop itself is strikingly similar\n",
    "to what we did when implementing everything from scratch.\n",
    "\n",
    "To refresh your memory: for some number of epochs,\n",
    "we will make a complete pass over the dataset (train_data),\n",
    "iteratively grabbing one minibatch of inputs\n",
    "and the corresponding ground-truth labels.\n",
    "For each minibatch, we go through the following ritual:\n",
    "\n",
    "* Generate predictions, calculate loss, and calculate gradients by calling `trainBatch(batch)` (forward pass and backward pass).\n",
    "* Update the model parameters by invoking the `step` function.\n",
    "\n",
    "`Logging` will automatically print the evaluators we are watching\n",
    "during each epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "10"
    }
   },
   "outputs": [],
   "source": [
    "int numEpochs = 3;\n",
    "\n",
    "for (int epoch = 1; epoch <= numEpochs; epoch++) {\n",
    "    System.out.printf(\"Epoch %d\\n\", epoch);\n",
    "    // Iterate over dataset\n",
    "    for (Batch batch : trainer.iterateDataset(dataset)) {\n",
    "        // Update loss and evaulator\n",
    "        EasyTrain.trainBatch(trainer, batch);\n",
    "        \n",
    "        // Update parameters\n",
    "        trainer.step();\n",
    "        \n",
    "        batch.close();\n",
    "    }\n",
    "    // reset training and validation evaluators at end of epoch\n",
    "    trainer.notifyListeners(listener -> listener.onEpoch(trainer));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below, we compare the model parameters learned by training on finite data\n",
    "and the actual parameters that generated our dataset.\n",
    "To access parameters with DJL,\n",
    "we first access the layer that we need from `model`\n",
    "and then access that layer's weight and bias through its parameter list\n",
    "by calling `getParameters()`.\n",
    "We then simply get each param with `valueAt()`.\n",
    "Here, `valueAt(0)` and `valueAt(1)` returns the weights and bias respectively.\n",
    "As in our from-scratch implementation,\n",
    "note that our estimated parameters are\n",
    "close to their ground truth counterparts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "12"
    }
   },
   "outputs": [],
   "source": [
    "Block layer = model.getBlock();\n",
    "ParameterList params = layer.getParameters();\n",
    "NDArray wParam = params.valueAt(0).getArray();\n",
    "NDArray bParam = params.valueAt(1).getArray();\n",
    "\n",
    "float[] w = trueW.sub(wParam.reshape(trueW.getShape())).toFloatArray();\n",
    "System.out.printf(\"Error in estimating w: [%f %f]\\n\", w[0], w[1]);\n",
    "System.out.printf(\"Error in estimating b: %f\\n\", trueB - bParam.getFloat());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving Your Model\n",
    "Now that you have trained your model, you probably want to save it\n",
    "for future use. Additionally, you probably also want to add\n",
    "metadata such as training accuracy and epochs trained.\n",
    "You can do this easily. Simply point to a file location with `Paths.get`.\n",
    "Metadata can be saved with the `setProperty` method.\n",
    "Then call the `save` method on the model to save it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Path modelDir = Paths.get(\"../models/lin-reg\");\n",
    "Files.createDirectories(modelDir);\n",
    "\n",
    "model.setProperty(\"Epoch\", Integer.toString(numEpochs)); // save epochs trained as metadata\n",
    "\n",
    "model.save(modelDir, \"lin-reg\");\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "* Using DJL, we can implement models much more succinctly.\n",
    "* In DJL, the `training.dataset` package provides tools for data processing, the `nn` package defines a large number of neural network layers, and the `Loss` class defines many common loss functions.\n",
    "* DJL's `training.initializer` package provides various methods for model parameter initialization.\n",
    "\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Review the DJL documentation to see what loss functions and initialization methods are provided in the class `Loss` and `Trainer`. Replace the loss with L1 Loss.\n",
    "1. How do you access the parameters during training?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "11.0.10+9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
